# Subgroup Balancing : Codebook

## (1) Illustrating Examples

## See process_outline.jpg for an outline of each step of the procedure.

### [0] Tool functions / Libraries

library(tidyr)
library(dplyr)
library(ggplot2)
library(MASS)
library(ggh4x)
library(Rcpp)
library(RcppArmadillo)
library(RcppParallel)
library(Rmosek)
library(Matrix)
library(RColorBrewer)

# Triangular kernel: height 1 - |x| on [-1, 1], zero elsewhere (vectorized).
tri = function(x) {
  height = 1 - abs(x)
  pmax(height, 0)
}

# Uniform kernel 
kernel_unif = function(W, w, h = 0.5) {
  W_matrix = as.matrix(W)
  w_vector = as.numeric(w)
  diff = abs(W_matrix - matrix(w_vector, nrow(W_matrix), ncol(W_matrix), byrow = TRUE))
  condition = (diff <= h )
  K = apply(condition, 1, all)
  return(K)
}

# IPW functions
# Kernel-localized IPW numerator at evaluation point w:
#   mean( K * (A * Y / pi_hat - (1 - A) * Y / (1 - pi_hat)) ),
# where K = kernel_unif(W, w, h) flags observations inside the box around w.
ipw_numerator = function(W, w, Y, A, pi_hat, h = 0.5) {
  in_window = kernel_unif(W, w, h = h)
  treated_part = A * Y / pi_hat
  control_part = (1 - A) * Y / (1 - pi_hat)
  mean((treated_part - control_part) * in_window)
}

# Kernel-localized IPW denominator: the fraction of observations whose
# covariates fall inside the uniform box of half-width h around w.
ipw_denominator = function(W, w, h = 0.5) {
  mean(kernel_unif(W, w, h = h))
}

# Local IPW estimate of the treatment effect at every observed covariate point.
# For row i:
#   tau_i = mean(K_i * psi) / mean(K_i),
#   psi   = A * Y / pi_hat - (1 - A) * Y / (1 - pi_hat),
#   K_i   = kernel_unif(W, W[i, ], h).
#
# Args:
#   W      : covariates used for localization (matrix, data.frame or vector).
#   Y      : outcome vector (expected length nrow(W)).
#   A      : 0/1 treatment indicator (expected length nrow(W)).
#   pi_hat : estimated propensity scores, used as P(A = 1 | covariates).
#   h      : half-width of the uniform kernel box.
# Returns:
#   numeric vector with one estimate per row of W. A row whose box is empty
#   yields NaN and triggers a warning; NA covariates yield NA, not an error.
ipw_estimate = function(W, Y, A, pi_hat, h = 0.5) {
  # as.matrix() also accepts a plain vector (one covariate), like kernel_unif.
  W_matrix = as.matrix(W)
  n = nrow(W_matrix)
  # The IPW pseudo-outcome does not depend on the evaluation point: compute once.
  terms = (A * Y / pi_hat) - ((1 - A) * Y / (1 - pi_hat))
  tau_ipw = numeric(n)
  for (i in seq_len(n)) {
    w = as.numeric(W_matrix[i, ])
    # Evaluate the kernel once per row and reuse it for both parts.
    K = kernel_unif(W_matrix, w, h = h)
    numerator = mean(terms * K)
    denominator = mean(K)
    # isTRUE(): an NA denominator must not make if() fail.
    if (isTRUE(denominator == 0)) {
      warning("Denominator is zero for i = ", i,
              " (w = ", toString(w), ", numerator = ", numerator, ")",
              call. = FALSE)
    }
    tau_ipw[i] = numerator / denominator
  }
  tau_ipw
}

### [1] Data generation

n = 1000
set.seed(980929)

# Sensitive (group) variable F ~ Bernoulli(0.5).
# Note: this global masks base R's F (= FALSE) shorthand for the rest of the script.
F = rbinom(n, 1, 0.5)

# Covariate X1 | F ~ Normal(mean = 4F - 2, sd = Sigma), i.e. mean -2 when F = 0
# and +2 when F = 1. Sigma is passed as a standard deviation, not a variance.
mu = 4 * F - 2
Sigma = 2.0
X1 = rnorm(n, mean = mu, sd = Sigma)

# Independent N(0, 1) noises: eps enters the outcome, delta the treatment rule.
# eps is drawn before delta; keep this order to reproduce the random stream.
eps = rnorm(n, 0, 1)
delta = rnorm(n, 0, 1)

# Treatment A = 1{F + 0.5 * F * X1 + delta > 0} and outcome
# Y = X1^2 + A * tri((X1 - 2) / 2) + (1 - A) * tri((X1 + 2) / 2) + eps.
# Only F, X1, A and Y are kept in df.
df = data.frame(F, X1, eps, delta) %>%
  mutate(
    A = as.integer(F + 0.5 * F * X1 + delta > 0),
    Y = X1^2 + A * tri((X1 - 2) / 2) + (1 - A) * tri((X1 + 2) / 2) + eps
  ) %>%
  dplyr::select(F, X1, A, Y)

####### Plot 0 : Verification of Consistency

# Both potential outcomes for every unit, and their difference.
# df no longer carries eps, so the global noise vector eps is used here.
# It cancels in Y_diff.
plotdf = df %>%
  mutate(
    Y1 = (X1)^2 + tri((X1 - 2) / 2) + eps,
    Y0 = (X1)^2 + tri((X1 + 2) / 2) + eps,
    Y_diff = Y1 - Y0
  )

# Individual effect Y^1 - Y^0 against X1; dashed lines mark X1 = -2 and X1 = 2.
ggplot(plotdf, aes(x = X1, y = Y_diff)) +
  geom_point(color = "red", alpha = 0.6) +
  geom_hline(yintercept = 0, linetype = "solid", color = "black") +
  geom_vline(xintercept = 0, linetype = "solid", color = "black") +
  geom_vline(xintercept = -2, color = "red", linetype = "dashed") +
  geom_vline(xintercept = 2, color = "red", linetype = "dashed") +
  labs(
    x = expression(X[1]),
    y = expression(Y^1 - Y^0)
  ) +
  scale_x_continuous(breaks = c(-4, -2, 0, 2, 4)) +
  theme_minimal()

### [2] IPW method 

# Localization covariates: sigX1 = log(X1^2) and the sensitive variable F.
W = transmute(df, sigX1 = log(X1^2), F)

# Logistic propensity model for A given the columns of W.
propensity_logit = glm(A ~ ., data = cbind(W, A = df$A), family = binomial)
# Clip fitted propensities to [0.001, 0.999] so that the inverse weights
# 1 / pi_hat and 1 / (1 - pi_hat) stay bounded.
propensity_score = predict(propensity_logit, type = "response") %>%
  pmin(0.999) %>%
  pmax(0.001)

# Local IPW effect estimate at each observation (uniform kernel, h = 0.5).
tau_ipw = ipw_estimate(W, df$Y, df$A, propensity_score, h = 0.5)

# Collect results. A positive estimated effect puts the unit in subgroup S1;
# otherwise the unit goes to S2.
finaldf = data.frame(W = W, A = df$A, pi.hat = propensity_score, tau.w = tau_ipw) %>%
  mutate(group = ifelse(tau.w > 0, "S1", "S2"))

### [3] Visualization

# ==================================================
# Plot 1 : (Imbalance in sensitive variable)
# ==================================================


# Level of the sensitive variable F (as a factor) for each unit in S1 / S2.
plot_data = finaldf %>%
  filter(group %in% c("S1", "S2")) %>%
  transmute(W.F = as.factor(W.F), group)

# Stacked proportions of F within each subgroup.
ggplot(plot_data, aes(x = group, fill = W.F)) +
  geom_bar(position = "fill") +
  scale_y_continuous(labels = scales::percent) +
  theme_minimal() +
  theme(
    legend.position = "right",
    legend.key.size = unit(1.5, "lines"),
    legend.title = element_text(size = 14),
    legend.text = element_text(size = 12),
    axis.text.x = element_text(size = 14)
  ) +
  labs(
    title = "Subgroup Fairness Distribution",
    x = "Group",
    y = "Proportion (%)",
    fill = "F"
  )


# ==================================================
# Plot 2: Density plot of X1 (Subgroup S1 / A = 0,1)
# ==================================================

# Filter for subgroup S1
fig2df = data.frame(df, pi.hat = propensity_score) %>%
  mutate(group = finaldf$group) %>%
  filter(group == "S1") %>%
  dplyr::select(X1, pi.hat, A)

# Split S1 by treatment arm (columns X1, pi.hat).
A1 = fig2df[fig2df$A == 1, 1:2]
A0 = fig2df[fig2df$A == 0, 1:2]

# Inverse-propensity weights: 1 / pi.hat for treated units and
# 1 / (1 - pi.hat) for controls. X1 itself is NOT rescaled: balance is judged
# by comparing the IPW-reweighted distributions of X1. Dividing X1 by pi.hat
# would plot the distribution of X1 / pi.hat instead.
A1 = A1 %>% mutate(ipw = 1 / pi.hat)
weighted.A1 = A1 %>% dplyr::select(X1, ipw)
A0 = A0 %>% mutate(ipw = 1 / (1 - pi.hat))
weighted.A0 = A0 %>% dplyr::select(X1, ipw)

plot_data_s1 = bind_rows(
  mutate(weighted.A1, group = "A=1"),
  mutate(weighted.A0, group = "A=0")
)

plot_data_long_s1 = plot_data_s1 %>%
  pivot_longer(cols = X1, names_to = "variable", values_to = "value")

# Trim the extreme 2% tails of X1 (pooled over arms). Then normalize the
# weights within each arm so that each weighted density integrates to 1.
plot_data_filtered_s1 = plot_data_long_s1 %>%
  group_by(variable) %>%
  filter(value > quantile(value, 0.02) & value < quantile(value, 0.98)) %>%
  group_by(variable, group) %>%
  mutate(ipw = ipw / sum(ipw)) %>%
  ungroup() %>%
  mutate(variable = factor(variable, levels = c("X1")))

ggplot(plot_data_filtered_s1,
       aes(x = value, weight = ipw, fill = group, color = group)) +
  geom_density(alpha = 0.3) +
  facet_wrap(~ variable, scales = "free", ncol = 3) +
  scale_fill_brewer(palette = "Set1") +
  scale_color_brewer(palette = "Set1") +
  labs(
    title = "Covariate Balance for Subgroup 1 (S1)",
    subtitle = "IPW-weighted densities by treatment arm",
    fill = "Group",
    color = "Group"
  ) +
  theme_minimal() +
  theme(legend.position = "bottom",
        legend.key.size = unit(1.5, "lines"),
        legend.title = element_text(size = 14),
        legend.text = element_text(size = 12))

# ==================================================
# Plot 3: Density plot of X1 (Subgroup S2 / A = 0,1)
# ==================================================

# Filter for subgroup S2
fig3df = data.frame(df, pi.hat = propensity_score) %>%
  mutate(group = finaldf$group) %>%
  filter(group == "S2") %>%
  dplyr::select(X1, pi.hat, A)

# Split S2 by treatment arm (columns X1, pi.hat).
A1_s2 = fig3df[fig3df$A == 1, 1:2]
A0_s2 = fig3df[fig3df$A == 0, 1:2]

# Inverse-propensity weights: 1 / pi.hat for treated units and
# 1 / (1 - pi.hat) for controls. X1 itself is NOT rescaled: balance is judged
# by comparing the IPW-reweighted distributions of X1. Dividing X1 by pi.hat
# would plot the distribution of X1 / pi.hat instead.
A1_s2 = A1_s2 %>% mutate(ipw = 1 / pi.hat)
weighted.A1_s2 = A1_s2 %>% dplyr::select(X1, ipw)
A0_s2 = A0_s2 %>% mutate(ipw = 1 / (1 - pi.hat))
weighted.A0_s2 = A0_s2 %>% dplyr::select(X1, ipw)

plot_data_s2 = bind_rows(
  mutate(weighted.A1_s2, group = "A=1"),
  mutate(weighted.A0_s2, group = "A=0")
)

plot_data_long_s2 = plot_data_s2 %>%
  pivot_longer(cols = X1, names_to = "variable", values_to = "value")

# Trim the extreme 2% tails of X1 (pooled over arms). Then normalize the
# weights within each arm so that each weighted density integrates to 1.
plot_data_filtered_s2 = plot_data_long_s2 %>%
  group_by(variable) %>%
  filter(value > quantile(value, 0.02) & value < quantile(value, 0.98)) %>%
  group_by(variable, group) %>%
  mutate(ipw = ipw / sum(ipw)) %>%
  ungroup() %>%
  mutate(variable = factor(variable, levels = c("X1")))

ggplot(plot_data_filtered_s2,
       aes(x = value, weight = ipw, fill = group, color = group)) +
  geom_density(alpha = 0.3) +
  facet_wrap(~ variable, scales = "free", ncol = 3) +
  scale_fill_brewer(palette = "Set2", direction = -1) +
  scale_color_brewer(palette = "Set2", direction = -1) +
  labs(
    title = "Covariate Balance for Subgroup 2 (S2)",
    subtitle = "IPW-weighted densities by treatment arm",
    fill = "Group",
    color = "Group"
  ) +
  theme_minimal() +
  theme(legend.position = "bottom",
        legend.key.size = unit(1.5, "lines"),
        legend.title = element_text(size = 14),
        legend.text = element_text(size = 12))





